# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is based on
# the config parser from here: https://github.com/emre/storm/blob/master/storm/parsers/ssh_config_parser.py
# Copyright (C) <2013> <Emre Yilmaz>
# SPDX-License-Identifier: MIT

from __future__ import annotations
import os
import re
import traceback
from operator import itemgetter


PARAMIKO_IMPORT_ERROR: str | None
try:
    from paramiko.config import SSHConfig
except ImportError:
    SSHConfig = object  # type: ignore
    HAS_PARAMIKO = False
    PARAMIKO_IMPORT_ERROR = traceback.format_exc()
else:
    HAS_PARAMIKO = True
    PARAMIKO_IMPORT_ERROR = None


class StormConfig(SSHConfig):
    # A config line is: keyword, then either "=" (optionally surrounded by
    # whitespace) or plain whitespace, then the argument. The keyword itself
    # can never contain "=", so "SetEnv FOO=bar" keeps "setenv" as keyword and
    # "FOO=bar" as argument instead of being split at the first "=".
    _KEY_VALUE_RE = re.compile(r"^([^\s=]+)(?:\s*=\s*|\s+)(.*)$")

    def parse(self, file_obj):
        """
        Read an OpenSSH config from the given file object.

        Every line is recorded, in file order, in self._config (a list that is
        expected to be initialised by the SSHConfig base class):
          - blank lines as {"type": "empty_line", ...}
          - comment lines as {"type": "comment", ...}
          - "Host" sections as {"host": [patterns], "config": {...},
            "type": "entry", "order": n}
        An implicit leading {"host": ["*"], "config": {...}} entry collects
        options that appear before the first "Host" line.
        Keywords are lower-cased; for a repeated keyword the first value is
        kept, except identityfile/localforward/remoteforward, which collect
        every value in a list.

        @param file_obj: a file-like object to read the config file from
        @type file_obj: file
        @raise Exception: for a non-blank, non-comment line without an argument
        """
        order = 1
        host = {
            "host": ["*"],
            "config": {},
        }
        for line in file_obj:
            line = line.rstrip("\n").lstrip()
            if line == "":
                self._config.append(
                    {
                        "type": "empty_line",
                        "value": line,
                        "host": "",
                        "order": order,
                    }
                )
                order += 1
                continue

            if line.startswith("#"):
                self._config.append(
                    {
                        "type": "comment",
                        "value": line,
                        "host": "",
                        "order": order,
                    }
                )
                order += 1
                continue

            match = self._KEY_VALUE_RE.match(line)
            if match is None:
                # Keyword with no separator/argument at all (e.g. "Host").
                raise Exception(f"Unparsable line: {line!r}")
            key = match.group(1).lower()
            # Leading whitespace is consumed by the separator; drop trailing
            # whitespace so values compare cleanly and dump() stays tidy.
            value = match.group(2).rstrip()

            if key == "host":
                # Close the current section and start a new one.
                self._config.append(host)
                host = {key: value.split(), "config": {}, "type": "entry", "order": order}
                order += 1
            elif key in ("identityfile", "localforward", "remoteforward"):
                # These keywords may legitimately repeat; keep every value.
                host["config"].setdefault(key, []).append(value)
            elif key not in host["config"]:
                # Only the first value seen for a keyword is kept.
                host["config"][key] = value
        self._config.append(host)


class ConfigParser:
    """
    Config parser for ~/.ssh/config files.

    After load(), ``config_data`` is a list of dicts, each either a host entry
    ({"host", "options", "type": "entry", "order"}) or a preserved comment /
    blank line ({"type", "value", "host": "", "order"}). dump() sorts by
    ``order`` to reproduce the file layout.
    """

    def __init__(self, ssh_config_file=None):
        """
        @param ssh_config_file: path of the config file; defaults to
            ~/.ssh/config. If the file does not exist it is created (with its
            parent directory when needed) and chmod-ed to 0600.
        """
        if not ssh_config_file:
            ssh_config_file = self.get_default_ssh_config_file()

        self.defaults = {}

        self.ssh_config_file = ssh_config_file

        if not os.path.exists(self.ssh_config_file):
            parent_dir = os.path.dirname(self.ssh_config_file)
            # A bare file name has an empty dirname; os.makedirs("") raises,
            # so only create a parent directory that is actually named.
            if parent_dir and not os.path.exists(parent_dir):
                os.makedirs(parent_dir)
            open(self.ssh_config_file, "w+").close()
            os.chmod(self.ssh_config_file, 0o600)

        self.config_data = []

    def get_default_ssh_config_file(self):
        """Return the expanded path of the per-user ~/.ssh/config file."""
        return os.path.expanduser("~/.ssh/config")

    def load(self):
        """
        Parse the config file and append its items to config_data.

        Options of a "Host *" section are also merged into self.defaults.

        @return: the config_data list
        """
        config = StormConfig()

        with open(self.ssh_config_file) as fd:
            config.parse(fd)

        for entry in config._config:
            if entry.get("host") == ["*"]:
                self.defaults.update(entry.get("config"))

            if entry.get("type") in ("comment", "empty_line"):
                self.config_data.append(entry)
                continue

            host_item = {
                # "Host a b c" keeps all of its patterns, space separated.
                "host": " ".join(entry["host"]),
                "options": entry.get("config"),
                "type": "entry",
                "order": entry.get("order", 0),
            }
            # Skip host entries without options; the parser always emits an
            # implicit leading "*" entry, which is usually empty.
            if entry.get("config"):
                self.config_data.append(host_item)

        return self.config_data

    def add_host(self, host, options):
        """
        Append a new host entry that sorts after every existing item.

        @param host: the Host pattern(s) as a string
        @param options: dict of lower-cased keyword -> value (or list of values)
        @return: self
        """
        self.config_data.append(
            {
                "host": host,
                "options": options,
                # Same shape as entries produced by load(), so search_host()
                # and update_host() treat added hosts like loaded ones.
                "type": "entry",
                "order": self.get_last_index() + 1,
            }
        )

        return self

    def update_host(self, host, options, use_regex=False):
        """
        Merge options into every matching host entry.

        @param host: exact host string to match, or a pattern passed to
            re.match (anchored at the start) when use_regex is True
        @param options: dict of options to set; an optional "deleted_fields"
            list names options to remove first. The caller's dict is not
            modified.
        @return: self
        """
        # Work on a copy so the caller's dict is untouched and every matching
        # entry (not only the first one) gets the deletions applied.
        options = dict(options)
        deleted_fields = options.pop("deleted_fields", [])

        for host_entry in self.config_data:
            # Comment / blank-line items carry no "options" to update.
            if host_entry.get("type") in ("comment", "empty_line"):
                continue
            entry_host = host_entry.get("host")
            if entry_host == host or (use_regex and re.match(host, entry_host)):
                entry_options = host_entry["options"]
                for deleted_field in deleted_fields:
                    entry_options.pop(deleted_field, None)
                entry_options.update(options)

        return self

    def search_host(self, search_string):
        """
        Return host entries whose host or option values contain search_string.

        Matching is a plain substring test; "Host *" and non-entry items are
        skipped.
        """
        results = []
        for host_entry in self.config_data:
            if host_entry.get("type") != "entry":
                continue
            if host_entry.get("host") == "*":
                continue

            searchable_information = host_entry.get("host")
            for value in host_entry.get("options").values():
                if isinstance(value, list):
                    value = " ".join(value)
                if isinstance(value, int):
                    value = str(value)

                searchable_information += f" {value}"

            if search_string in searchable_information:
                results.append(host_entry)

        return results

    def delete_host(self, host):
        """
        Remove every item whose host equals host.

        @raise ValueError: if no item matched
        @return: self
        """
        # Filter instead of deleting while enumerating, which skipped the item
        # following each deletion (adjacent duplicates survived).
        remaining = [entry for entry in self.config_data if entry.get("host") != host]
        if len(remaining) == len(self.config_data):
            raise ValueError("No host found")
        # Slice assignment keeps the list identity returned by load().
        self.config_data[:] = remaining
        return self

    def delete_all_hosts(self):
        """Drop all items (including comments) and rewrite the file."""
        self.config_data = []
        self.write_to_ssh_config()

        return self

    def dump(self):
        """
        Render config_data as SSH config text, sorted by "order".

        @return: the file content, or None when config_data is empty
        """
        if not self.config_data:
            return None

        self.config_data = sorted(self.config_data, key=itemgetter("order"))

        lines = []
        for host_item in self.config_data:
            if host_item.get("type") in ("comment", "empty_line"):
                lines.append(f"{host_item.get('value')}\n")
                continue
            lines.append(f"Host {host_item.get('host')}\n")
            for key, value in host_item.get("options").items():
                # Multi-valued keywords are written once per value.
                values = value if isinstance(value, list) else [value]
                lines.extend(f"    {key} {value_}\n" for value_ in values)

        return "".join(lines)

    def write_to_ssh_config(self):
        """Overwrite the config file with dump() output (empty if no data)."""
        # Render before opening: "w+" truncates immediately, so a failure in
        # dump() would otherwise leave an emptied config file behind.
        data = self.dump()
        with open(self.ssh_config_file, "w+") as f:
            if data:
                f.write(data)
        return self

    def get_last_index(self):
        """Return the largest truthy "order" value in config_data, or 0."""
        return max(
            (item.get("order") for item in self.config_data if item.get("order")),
            default=0,
        )
